import logging
from abc import ABCMeta
from enum import Enum
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field

from .function_call import ChatCompletionFunctionCall

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Public API of this module (the names exported by ``from ... import *``).
# ``ChatCompletionImageUrl`` is exported because it is the field type of
# ``ChatCompletionImageUrlContent.image_url``; ``ChatCompletionContentType``
# enumerates the content-part discriminator values.
__all__ = [
    "ChatCompletionRole",
    "ChatCompletionContentType",
    "ChatCompletionMessage",
    "ChatCompletionSystemMessage",
    "ChatCompletionUserMessage",
    "ChatCompletionAssistantMessage",
    "ChatCompletionBaseContent",
    "ChatCompletionTextContent",
    "ChatCompletionImageUrl",
    "ChatCompletionImageUrlContent",
    "ChatCompletionMultimodalContent",
    "ChatCompletionFunctionMessage",
    "ChatCompletionFinishReason",
    "ChatCompletion",
    "ChatCompletionChunk",
    "is_assistant_text_message",
    "is_assistant_function_calls_message",
    "ChatCompletionUsage",
]


class ChatCompletionRole(str, Enum):
    """Role of the author of a chat completion message.

    Because the enum mixes in ``str``, every member compares equal to its
    plain string value, e.g. ``ChatCompletionRole.user == "user"``.
    """

    system = "system"  # role of ChatCompletionSystemMessage
    assistant = "assistant"  # role of ChatCompletionAssistantMessage / chunks
    user = "user"  # role of ChatCompletionUserMessage
    function = "function"  # role of ChatCompletionFunctionMessage


class ChatCompletionContentType(str, Enum):
    """Discriminator values for the parts of multimodal message content.

    Members are ``str`` subclasses and compare equal to their raw values.
    """

    TEXT = "text"  # a text part
    IMAGE_URL = "image_url"  # an image referenced by URL


class ChatCompletionBaseContent(BaseModel):
    """Common base class for the parts of multimodal message content.

    It declares no fields of its own; concrete content parts subclass it.
    """


class ChatCompletionTextContent(ChatCompletionBaseContent):
    """A text part of multimodal message content.

    Attributes:
        type: Discriminator, always ``"text"``; any other value fails validation.
        text: The text itself (required).
    """

    # Typed as ``Literal`` so pydantic actually rejects other values.
    # ``Field`` has no ``Literal=`` option: such a keyword is only stored as
    # extra schema metadata and enforces nothing.
    type: Literal["text"] = Field(
        "text",
        description="The type of the content.",
    )
    text: str = Field(
        ...,
        description="The text content.",
    )


class ChatCompletionImageUrl(BaseModel):
    """Location of an image referenced from multimodal message content."""

    # Required: ``default=...`` marks the field as mandatory.
    url: str = Field(
        default=...,
        description="The url of the image.",
    )

    # TODO: an optional ``details: Optional[str]`` field ("The details of the
    # image.") was drafted here but is currently disabled.


class ChatCompletionImageUrlContent(ChatCompletionBaseContent):
    """An image part of multimodal message content.

    Attributes:
        type: Discriminator, always ``"image_url"``; any other value fails
            validation.
        image_url: The referenced image (required).
    """

    # Typed as ``Literal`` so pydantic actually rejects other values; a
    # ``Literal=`` keyword on ``Field`` is mere schema metadata.
    type: Literal["image_url"] = Field(
        "image_url",
        description="The type of the content.",
    )
    image_url: ChatCompletionImageUrl = Field(
        ...,
        description="The image url content.",
    )


# Multimodal message content: an ordered list whose items are either text
# parts or image-url parts.
ChatCompletionMultimodalContent = List[
    Union[
        ChatCompletionTextContent,
        ChatCompletionImageUrlContent,
    ]
]


# Base message class
class ChatCompletionMessage(BaseModel, metaclass=ABCMeta):
    """Base class shared by all chat completion message types.

    Only ``content`` is declared here; the role-specific subclasses add a
    ``role`` field and any role-specific fields.

    Attributes:
        content: ``None`` (the default), a plain string, or a list of text /
            image_url parts for vision-capable models.
    """

    content: Optional[Union[str, ChatCompletionMultimodalContent]] = Field(
        default=None,
        description="The content of the message. For vision LLMs, it can be a list of text and image_url.",
    )


# Subclasses for each message type with specific role
class ChatCompletionSystemMessage(ChatCompletionMessage):
    """A message whose role defaults to ``ChatCompletionRole.system``."""

    # NOTE(review): ``Literal=`` is not a constraint pydantic enforces (unknown
    # ``Field`` keywords become extra JSON-schema metadata), so any
    # ChatCompletionRole value validates here; only the default is fixed.
    role: ChatCompletionRole = Field(
        default=ChatCompletionRole.system,
        Literal=ChatCompletionRole.system,
        description="The role of the message, which is always 'system' for a system message",
    )


class ChatCompletionUserMessage(ChatCompletionMessage):
    """A message whose role defaults to ``ChatCompletionRole.user``."""

    # NOTE(review): ``Literal=`` is only extra schema metadata to pydantic, so
    # any ChatCompletionRole value validates here; only the default is fixed.
    role: ChatCompletionRole = Field(
        default=ChatCompletionRole.user,
        Literal=ChatCompletionRole.user,
        description="The role of the message, which is always 'user' for a user message",
    )


class ChatCompletionAssistantMessage(ChatCompletionMessage):
    """A message produced by the assistant.

    Attributes:
        role: Defaults to ``ChatCompletionRole.assistant``.
        function_calls: Function calls requested by the assistant, or ``None``
            (the default) when the message carries no function calls.
    """

    # NOTE(review): ``Literal=`` is only extra schema metadata to pydantic, so
    # any ChatCompletionRole value validates here; only the default is fixed.
    role: ChatCompletionRole = Field(
        default=ChatCompletionRole.assistant,
        Literal=ChatCompletionRole.assistant,
        description="The role of the message, which is always 'assistant' for an assistant message",
    )
    function_calls: Optional[List[ChatCompletionFunctionCall]] = Field(
        default=None,
        description="The function calls requested by the assistant.",
    )


class ChatCompletionFunctionMessage(ChatCompletionMessage):
    """A message carrying a function result back to the model.

    Attributes:
        role: Defaults to ``ChatCompletionRole.function``.
        id: Required id of the tool call, as requested by the assistant, that
            this message answers.
    """

    # NOTE(review): ``Literal=`` is only extra schema metadata to pydantic, so
    # any ChatCompletionRole value validates here; only the default is fixed.
    role: ChatCompletionRole = Field(
        default=ChatCompletionRole.function,
        Literal=ChatCompletionRole.function,
        description="The role of the message, which is always 'function' for a function message",
    )
    id: str = Field(
        default=...,
        description="The corresponding id of the tool requested by the assistant.",
    )


class ChatCompletionFinishReason(str, Enum):
    """Why a chat completion stopped generating.

    Members are ``str`` subclasses and compare equal to their raw values.
    """

    stop = "stop"
    length = "length"
    function_calls = "function_calls"
    error = "error"
    recitation = "recitation"
    unknown = "unknown"
    # TODO: add content_filter


class ChatCompletionUsage(BaseModel):
    """Token accounting for one completion.

    Attributes:
        input_tokens: Tokens consumed by the input (required).
        output_tokens: Tokens produced in the output (required).
    """

    input_tokens: int = Field(
        default=...,
        description="The number of tokens in the input.",
    )
    output_tokens: int = Field(
        default=...,
        description="The number of tokens in the output.",
    )


class ChatCompletion(BaseModel):
    """A complete (non-streaming) chat completion response.

    Attributes:
        object: Discriminator, always ``"ChatCompletion"``; other values fail
            validation.
        finish_reason: Why generation stopped.
        message: The assistant message that was generated.
        created_timestamp: Creation time in milliseconds.
        usage: Token usage of the response.
        fallback_index: Index of the fallback model used, or ``None``.
    """

    # Typed as ``Literal`` so the discriminator is enforced; a ``Literal=``
    # keyword on ``Field`` is mere schema metadata and checks nothing.
    object: Literal["ChatCompletion"] = Field(
        "ChatCompletion",
        description="The object type, which is always 'ChatCompletion'.",
    )
    finish_reason: ChatCompletionFinishReason = Field(
        ...,
        description="The reason why the generation is finished.",
    )
    message: ChatCompletionAssistantMessage = Field(
        ...,
        description="The message generated by the assistant.",
    )
    created_timestamp: int = Field(
        ...,
        description="The timestamp in milliseconds when the response is created.",
    )
    usage: ChatCompletionUsage = Field(
        ...,
        description="The token usage of the response.",
    )
    fallback_index: Optional[int] = Field(
        None,
        description="The index of the fallback model used.",
    )


class ChatCompletionChunk(BaseModel):
    """One incremental piece of a streamed chat completion.

    Attributes:
        object: Discriminator, always ``"ChatCompletionChunk"``; other values
            fail validation.
        role: Defaults to ``ChatCompletionRole.assistant``.
        index: Position of this chunk within the message.
        delta: The newly generated text carried by this chunk.
        created_timestamp: Creation time in milliseconds.
    """

    # Typed as ``Literal`` so the discriminator is enforced; a ``Literal=``
    # keyword on ``Field`` is mere schema metadata and checks nothing.
    object: Literal["ChatCompletionChunk"] = Field(
        "ChatCompletionChunk",
        description="The object type, which is always 'ChatCompletionChunk'.",
    )
    # NOTE(review): as above, ``Literal=`` does not restrict the role; any
    # ChatCompletionRole value validates. Left enum-typed so the parsed value
    # stays a ChatCompletionRole member.
    role: ChatCompletionRole = Field(
        ChatCompletionRole.assistant,
        Literal=ChatCompletionRole.assistant,
        description="The role of the chunk.",
    )
    index: int = Field(
        ...,
        description="The index of the chunk in the message.",
    )
    delta: str = Field(
        ...,
        description="The delta content generated by the streaming inference.",
    )
    created_timestamp: int = Field(
        ...,
        description="The timestamp in milliseconds when the chunk is created.",
    )


def is_assistant_text_message(message: ChatCompletionMessage) -> bool:
    """Return True if *message* is an assistant message without function calls.

    Args:
        message: Any chat completion message.

    Returns:
        True when the message's role is ``assistant`` and it has no (or ``None``)
        ``function_calls``; False otherwise.
    """
    # Neither ``role`` nor ``function_calls`` is declared on the base
    # ChatCompletionMessage, and the subclasses do not enforce their role
    # (e.g. a ChatCompletionUserMessage may carry role="assistant"), so read
    # both attributes defensively instead of risking AttributeError.
    return (
        getattr(message, "role", None) == ChatCompletionRole.assistant
        and getattr(message, "function_calls", None) is None
    )


def is_assistant_function_calls_message(message: ChatCompletionMessage) -> bool:
    """Return True if *message* is an assistant message requesting function calls.

    Args:
        message: Any chat completion message.

    Returns:
        True when the message's role is ``assistant`` and its ``function_calls``
        attribute exists and is not ``None``; False otherwise.
    """
    # ``role`` / ``function_calls`` are not declared on the base class, and a
    # non-assistant message class may still carry role="assistant"; getattr
    # keeps this predicate from raising AttributeError in those cases.
    return (
        getattr(message, "role", None) == ChatCompletionRole.assistant
        and getattr(message, "function_calls", None) is not None
    )
